import numpy as np
import random
import math
from sklearn.linear_model import LinearRegression

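# Instructions: first call a run*() function to generate its result file,
# then call the matching plot*() function to visualise it.
'''
Parameters shared by the simulations:
n: number of individuals
value: strength of the interference (bias) paths; each such path is weighted by sqrt(value)
iters: number of Monte Carlo repetitions used to estimate the expectation
'''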
def _resolveGraphParams(n, dRate, rRate, totalDef, totalRef):
    """Return ``(dRate, rRate, totalDef, totalRef)`` with each rate/total pair made consistent.

    A total of -1 means "derive the count from the rate"; any other total
    overrides the rate.  Deflect counts are relative to the n*(n-1) ordered
    pairs, reflect counts relative to the n individuals.
    """
    numPairs = n * (n - 1)
    if totalDef == -1:  # passed in rate
        totalDef = int(dRate * n * (n - 1))
    else:  # passed in number; with n < 2 there is no pair to divide by
        dRate = totalDef / numPairs if numPairs else 0.0
    if totalRef == -1:
        totalRef = int(rRate * n)
    else:
        rRate = totalRef / n if n else 0.0
    return dRate, rRate, totalDef, totalRef


def _sampleDeflectEdges(n, dRate, totalDef, XMapDe, YMapDe):
    """Add distinct deflect edges ``Yj <- Ci -> Xi`` until ``totalDef`` of them exist.

    Ordered pairs (i, j), i != j, are scanned repeatedly in index order and a
    not-yet-present pair is accepted with probability ``dRate``.  An accepted
    edge stores ``j`` in ``XMapDe[i]`` and ``i`` in ``YMapDe[j]`` (both mutated
    in place).  Only new edges are counted, and the target is capped at
    n*(n-1) so the scan can terminate.
    """
    totalDef = min(totalDef, n * (n - 1))
    if totalDef <= 0 or dRate <= 0:
        return
    countDef = 0
    while True:
        for i in range(n):
            for j in range(n):
                if countDef >= totalDef:
                    return
                if i == j or j in XMapDe[i]:
                    continue
                if random.random() < dRate:
                    XMapDe[i].add(j)
                    YMapDe[j].add(i)
                    countDef += 1


def _sampleReflectEdges(n, rRate, totalRef, ReID):
    """Add distinct reflect edges ``Xi -> Mj -> Yi`` (j != i) until ``totalRef`` of them exist.

    Individuals are scanned repeatedly in index order; each one draws a reflect
    with probability ``rRate`` and, if it does, a uniformly random mediator
    j != i is stored in ``ReID[i]`` (mutated in place).  Only new (i, j) edges
    are counted, and the target is capped at n*(n-1) so the scan can terminate
    (previously n == 1 looped forever looking for j != i).
    """
    totalRef = min(totalRef, n * (n - 1))
    if totalRef <= 0 or rRate <= 0:
        return
    countRef = 0
    while True:
        for i in range(n):
            if countRef >= totalRef:
                return
            if random.random() < rRate:
                j = i
                while j == i:
                    j = random.choice(range(n))
                if j not in ReID[i]:
                    ReID[i].add(j)
                    countRef += 1


def _greedyMaxSelectedSet(n, XMapDe, YMapDe, ReID, trials=10):
    """Return the largest interference-free subset found by randomized greedy passes.

    Each of ``trials`` passes visits the individuals in a random order and keeps
    individual i when it has no reflect edge and no deflect edge (in either
    direction) to an individual already kept in that pass.
    """
    best = set()
    for _ in range(trials):
        order = list(range(n))
        random.shuffle(order)
        chosen = set()
        for i in order:
            if not ReID[i] and XMapDe[i].isdisjoint(chosen) and YMapDe[i].isdisjoint(chosen):
                chosen.add(i)
        if len(chosen) > len(best):
            best = chosen
    return best


def _meanOrNan(values):
    """Return the mean of ``values`` as a float, or NaN for an empty sequence."""
    return float(np.mean(values)) if len(values) else float("nan")


def Simulate(experimentNum, n, dRate, rRate, alpha, value, iters, totalDef = -1, totalRef = -1):
    """Compare the pooled OLS slope of Y on X with the slope on an interference-free subset.

    Structure: deflect ``Yj <- Ci -> Xi``, reflect ``Xi -> Mj -> Yi``.

    Args:
        experimentNum: 1 appends the summary line to additionalResult_n_new.txt,
            2 to additionalResult_value_new.txt; any other value writes nothing.
        n: number of individuals.
        dRate: deflect probability per ordered pair (overridden when totalDef != -1).
        rRate: reflect probability per individual (overridden when totalRef != -1).
        alpha: coefficient of X in Y.
        value: every interference path is weighted by sqrt(value).
        iters: number of Monte Carlo repetitions.
        totalDef, totalRef: number of distinct deflect / reflect edges, or -1 to
            derive them from dRate / rRate.

    Returns:
        ``(subset size, mean pooled slope, mean subset slope)``; the means are NaN
        when ``iters`` is 0.  The line written to file is
        "n size dRate rRate value meanPooled meanSubset 0".
    """
    XMapDe = [set() for _ in range(n)]  # XMapDe[i]: every j with C_i -> Y_j
    YMapDe = [set() for _ in range(n)]  # YMapDe[j]: every i with C_i -> Y_j
    ReID = [set() for _ in range(n)]    # ReID[i]: mediators j of X_i -> M_j -> Y_i; empty if no reflect

    dRate, rRate, totalDef, totalRef = _resolveGraphParams(n, dRate, rRate, totalDef, totalRef)
    _sampleDeflectEdges(n, dRate, totalDef, XMapDe, YMapDe)
    _sampleReflectEdges(n, rRate, totalRef, ReID)
    selectedIdList = list(_greedyMaxSelectedSet(n, XMapDe, YMapDe, ReID))

    w = math.sqrt(value)  # weight of every interference path
    traditional = []      # pooled OLS slope per repetition
    my = []               # subset OLS slope per repetition (0 when the subset is empty)
    for _ in range(iters):
        C = np.random.normal(0, 1, n)
        X = w * C + np.random.normal(0, 1, n)
        Y = alpha * X + np.random.normal(0, 1, n)
        M = w * X + np.random.normal(0, 1, n)
        # Deflect: C_i leaks into Y_j.
        for yid in range(n):
            for xid in YMapDe[yid]:
                Y[yid] += w * C[xid]
        # Reflect: update every mediator first, then let Y_i read the updated M_j.
        for xyid in range(n):
            for mid in ReID[xyid]:
                M[mid] += -w * X[xyid]
        for xyid in range(n):
            for mid in ReID[xyid]:
                Y[xyid] += w * M[mid]

        reg = LinearRegression().fit(X.reshape(-1, 1), Y)
        traditional.append(reg.coef_[0])
        if selectedIdList:
            reg2 = LinearRegression().fit(X[selectedIdList].reshape(-1, 1), Y[selectedIdList])
            my.append(reg2.coef_[0])
        else:
            my.append(0)

    # Size of the best subset found (not of whichever greedy pass happened to run last).
    subsetSize = len(selectedIdList)
    meanTraditional = _meanOrNan(traditional)
    meanMy = _meanOrNan(my)
    outFile = {1: "additionalResult_n_new.txt", 2: "additionalResult_value_new.txt"}.get(experimentNum)
    if outFile is not None:
        with open(outFile, "a") as f:
            f.write(f"{n} {subsetSize} {dRate} {rRate} {value} {meanTraditional} {meanMy} 0\n")
    return subsetSize, meanTraditional, meanMy


def SimulateIncludingThm2(n, dRate, rRate, alpha, value, iters, totalDef = -1, totalRef = -1):
    """Build a random interference graph and log the size of the selected subset.

    Uses the same graph construction and greedy selection as ``Simulate`` but
    runs no regressions: ``alpha``, ``value`` and ``iters`` are unused and kept
    only for signature compatibility.  Appends "n size dRate rRate" to
    additionalResult_size_vs_density.txt and returns the subset size.
    """
    XMapDe = [set() for _ in range(n)]
    YMapDe = [set() for _ in range(n)]
    ReID = [set() for _ in range(n)]

    dRate, rRate, totalDef, totalRef = _resolveGraphParams(n, dRate, rRate, totalDef, totalRef)
    _sampleDeflectEdges(n, dRate, totalDef, XMapDe, YMapDe)
    _sampleReflectEdges(n, rRate, totalRef, ReID)
    # Size of the best subset found (not of the last greedy pass).
    subsetSize = len(_greedyMaxSelectedSet(n, XMapDe, YMapDe, ReID))

    with open("additionalResult_size_vs_density.txt", "a") as f:
        f.write(f"{n} {subsetSize} {dRate} {rRate}\n")
    return subsetSize


def CaseStudy(uMapAffectedBy, gMapAffectedBy, selectedIdList, n):
    """Draw one sample of the case-study model and fit OLS slopes of G on T.

    Per unit: T ~ Normal(mean 5, sd 2) clipped below at 0, U = 2*T + noise,
    G = U + noise (standard-normal noise).  Then, in unit order,
    U[j] += U[i] for each i in uMapAffectedBy[j], and afterwards
    G[i] += U[j] for each j in gMapAffectedBy[i].

    Returns:
        ``(len(selectedIdList), slope over all units, slope over the selected
        units)``.  The subset slope is 0 when ``selectedIdList`` is empty, the
        same convention ``Simulate`` uses.
    """
    T = np.maximum(np.random.normal(5, 2, n), 0)
    U = 2 * T + np.random.normal(0, 1, n)
    G = U + np.random.normal(0, 1, n)

    # U->U edges.  Updates are in place and in index order, so U[helper] already
    # includes any contribution added to it earlier in this loop.
    for uid in range(n):
        for helper in uMapAffectedBy[uid]:
            U[uid] += U[helper]

    # U->G edges use the fully propagated U.
    for gid in range(n):
        for uid in gMapAffectedBy[gid]:
            G[gid] += U[uid]

    reg = LinearRegression().fit(T.reshape(-1, 1), G)
    if len(selectedIdList) == 0:
        # Fitting on zero samples would raise; report 0 instead.
        subsetSlope = 0
    else:
        reg2 = LinearRegression().fit(T[selectedIdList].reshape(-1, 1), G[selectedIdList])
        subsetSlope = reg2.coef_[0]
    return len(selectedIdList), reg.coef_[0], subsetSlope

def Wrapper(n=500, totalDef=500*5, reflectProb=0.5, trials=1000, outPath="case_study_data.txt"):
    """Run the case study and print/return the averages over ``trials`` runs.

    Graph: unordered pairs i < j are scanned repeatedly in index order and the
    edge ``ui -> uj`` is accepted with probability dRate until ``totalDef``
    distinct edges exist.  Each new edge also gets ``uj -> gi`` with probability
    ``reflectProb``, creating the path ``ui -> uj -> gi``.

    Selection: 10 randomized greedy passes (largest kept) over units with no
    incoming U->G edge and no U->U edge to an already selected unit.  This
    mirrors Simulate's deflect/reflect rule; NOTE(review): confirm this
    U->U / U->G mapping is the intended one.

    Each trial calls ``CaseStudy`` and appends "reg1 reg2 length" to ``outPath``.

    Returns:
        ``(average subset size, average pooled slope, average subset slope)``,
        which are also printed.  NaNs when ``trials`` is 0.
    """
    import itertools

    uMapAffectedBy = [set() for _ in range(n)]  # uMapAffectedBy[j]: every i with ui -> uj
    uMapAffects = [set() for _ in range(n)]     # uMapAffects[i]: every j with ui -> uj
    gMapAffectedBy = [set() for _ in range(n)]  # gMapAffectedBy[i]: every j with uj -> gi

    # Add U->U edges; only new edges count, and the target is capped at the
    # number of unordered pairs so the scan can terminate.
    totalDef = min(totalDef, n * (n - 1) // 2)
    if totalDef > 0:
        dRate = totalDef / (n * (n - 1)) * 2
        countDef = 0
        while countDef < totalDef:
            for i, j in itertools.combinations(range(n), 2):
                if countDef >= totalDef:
                    break
                if j in uMapAffects[i]:
                    continue
                if random.random() < dRate:
                    uMapAffectedBy[j].add(i)  # ui --> uj
                    uMapAffects[i].add(j)
                    countDef += 1
                    if random.random() < reflectProb:
                        gMapAffectedBy[i].add(j)  # uj --> gi

    # Start constructing the subset.
    maxSet = set()
    for _ in range(10):
        order = list(range(n))
        random.shuffle(order)
        chosen = set()
        for i in order:
            if (not gMapAffectedBy[i]
                    and uMapAffects[i].isdisjoint(chosen)
                    and uMapAffectedBy[i].isdisjoint(chosen)):
                chosen.add(i)
        if len(chosen) > len(maxSet):
            maxSet = chosen
    selectedIdList = list(maxSet)

    lengths, pooledSlopes, subsetSlopes = [], [], []
    with open(outPath, "a") as f:
        for _ in range(trials):
            length, reg1, reg2 = CaseStudy(uMapAffectedBy, gMapAffectedBy, selectedIdList, n)
            lengths.append(length)
            pooledSlopes.append(reg1)
            subsetSlopes.append(reg2)
            f.write(f"{reg1} {reg2} {length}\n")

    def average(values):
        return sum(values) / len(values) if values else float("nan")

    avgLength = average(lengths)
    avgReg1 = average(pooledSlopes)
    avgReg2 = average(subsetSlopes)
    print(avgLength, avgReg1, avgReg2)
    return avgLength, avgReg1, avgReg2













def runExperiment1():
    """Experiment 1: sweep the population size n = 100, 200, ..., 9900.

    Every run uses 100 deflect and 100 reflect edges and writes through
    Simulate with experimentNum=1.
    """
    numIters = 10000
    biasValue = 100     # the value of the bias paths
    directEffect = 100  # the value of the direct effect
    for popSize in range(100, 10000, 100):
        Simulate(1, popSize, 0, 0, alpha=directEffect, value=biasValue,
                 iters=numIters, totalDef=100, totalRef=100)

def runExperiment2():
    """Experiment 2: sweep the bias-path strength ``value`` from 1000 down to 0.

    Population size is fixed at n = 1000, with 100 deflect and 100 reflect
    edges. Each run writes through Simulate with experimentNum=2.
    """
    iters = 10000
    alpha = 100  # the value of the direct effect
    n = 1000
    # The stop value -10 is exclusive, so the sweep includes value = 0.
    for value in range(1000, -10, -10):
        Simulate(2, n, 0, 0, alpha, value, iters, 100, 100)

def runSubsetExperiment():
    """Sweep deflect and reflect rates over 0.00, 0.02, ..., 0.48 for n = 10000.

    Each (dRate, rRate) combination triggers one SimulateIncludingThm2 call.
    """
    numIters = 10000
    biasValue = 100     # the value of the bias paths
    directEffect = 100  # the value of the direct effect
    popSize = 10000
    percents = range(0, 50, 2)
    for dPct in percents:
        for rPct in percents:
            SimulateIncludingThm2(popSize, dPct / 100.0, rPct / 100.0,
                                  alpha=directEffect, value=biasValue, iters=numIters)

def runCaseStudy():
    """Entry point for the case study; all the work is done by Wrapper()."""
    Wrapper()

def plotExperiment1():
    """Plot the experiment-1 results (estimated slope vs. n).

    Reads additionalResult_n_new.txt, which Simulate appends to when
    experimentNum=1. It plots column 5 (pooled "REG" slope) and column 6
    (subset "THM-2" slope) against column 0 (n), plus a dashed reference line
    at 100. The first line of the file is only printed, not plotted;
    presumably it is a stale or partial row. TODO confirm.

    NOTE(review): ``plt`` (matplotlib.pyplot) is not imported anywhere in this
    file; ``import matplotlib.pyplot as plt`` must be added for this to run.
    """
    with open("additionalResult_n_new.txt", "r") as f:
        fileList = f.readlines()

    # Skip the first line; ignore blank lines instead of crashing on them.
    rows = [line.split() for line in fileList[1:] if line.strip()]
    X = [int(row[0]) for row in rows]
    Y = [float(row[5]) for row in rows]
    Ymy = [float(row[6]) for row in rows]
    print(fileList[0])
    print(fileList[0].split())

    plt.plot(X, Y)
    plt.plot(X, Ymy)
    loc, _ = plt.yticks()
    plt.yticks(np.arange(int(min(loc)), int(max(loc)), step = 10))
    plt.xticks(np.arange(0, 10001, step = 1000))
    plt.plot([0, 10000], [100, 100], '--')
    plt.xlabel("n")
    plt.ylabel(chr(946)+"_YX")
    plt.legend(["Result from REG", "Result from THM-2", "TACE"], loc=0, frameon=True)
    plt.show()

def plotExperiment2():
    """Plot the experiment-2 results (estimated slope vs. bias strength ``value``).

    Reads every non-blank line of additionalResult_value_new.txt, which
    Simulate appends to when experimentNum=2. It plots column 5 (pooled "REG"
    slope) and column 6 (subset "THM-2" slope) against column 4 (value). The
    x axis runs from large to small values, and a dashed reference line is
    drawn at 100.

    NOTE(review): ``plt`` (matplotlib.pyplot) is not imported anywhere in this
    file; ``import matplotlib.pyplot as plt`` must be added for this to run.
    """
    with open("additionalResult_value_new.txt", "r") as f:
        fileList = f.readlines()

    rows = [line.split() for line in fileList if line.strip()]
    X = [float(row[4]) for row in rows]  # float also accepts values written as e.g. "100.0"
    Y = [float(row[5]) for row in rows]
    Ymy = [float(row[6]) for row in rows]
    print(fileList[0])
    print(fileList[0].split())

    plt.plot(X, Y)
    plt.plot(X, Ymy)
    loc, _ = plt.yticks()
    plt.yticks(np.arange(int(min(loc)), int(max(loc)), step = 10))
    plt.xticks(np.arange(1000, -1, step = -100))
    plt.plot([0, 1000], [100, 100], '--')
    plt.xlabel("value")
    plt.ylabel(chr(946)+'_YX')
    # Reversed limits so the x axis runs from large to small values.
    plt.xlim(max(X)+50, min(X)-50)
    plt.legend(["Result from REG", "Result from THM-2", "TACE"], loc=0, frameon=True)
    plt.show()

def _loadCaseStudyData(path="case_study_data.txt"):
    """Read the "reg1 reg2 length" lines that Wrapper appends to ``path``.

    Returns three lists: column 0 (pooled "REG" slopes), column 1 (subset
    "THM-2" slopes) and column 2 (subset sizes, as int).  Blank lines are skipped.
    """
    reg, thm2, length = [], [], []
    with open(path, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            reg.append(float(fields[0]))
            thm2.append(float(fields[1]))
            length.append(int(fields[2]))
    return reg, thm2, length


def _plotCaseStudyHistogram(values, reg):
    """Print the mean pooled slope ``reg``, then show a 10-bin histogram of ``values``.

    NOTE(review): ``plt`` (matplotlib.pyplot) is not imported anywhere in this
    file; ``import matplotlib.pyplot as plt`` must be added for this to run.
    """
    print(np.mean(reg))
    plt.hist(values, bins=10)
    plt.xlabel(chr(946))
    plt.ylabel('number of occurences')
    plt.show()


def plotCaseStudyThm2():
    """Show a histogram of the subset (THM-2) slopes from case_study_data.txt.

    The mean pooled (REG) slope is printed first.
    """
    reg, thm2, _ = _loadCaseStudyData()
    _plotCaseStudyHistogram(thm2, reg)


def plotCaseStudyREG():
    """Show a histogram of the pooled (REG) slopes from case_study_data.txt.

    Their mean is printed first.
    """
    reg, _, _ = _loadCaseStudyData()
    _plotCaseStudyHistogram(reg, reg)

